package cn.qu.dockman.mgr.ws;

import cn.qu.dockman.mgr.entity.Host;
import cn.qu.dockman.mgr.service.HostService;
import cn.qu.dockman.protocol.command.Command;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

import javax.security.auth.DestroyFailedException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Asynchronous, single-consumer outbound queue that delivers {@link Command}s to node machines
 * (hosts) over their WebSocket sessions.
 *
 * <p>Producers call {@link #send(Integer, Command)}. A dedicated background thread, started in
 * {@link #afterPropertiesSet()}, takes messages off the queue and writes each one to the session
 * registered for its target host. Sessions are (un)registered by matching the session's remote IP
 * against {@link HostService#findByIp}, via {@link #newSession} and {@link #removeSession}.
 *
 * <p>After {@link #destroy()} the queue stops accepting new commands. The consumer thread exits
 * once every command already queued has been processed.
 *
 * <p>Created by zh on 17/3/9.
 */
public class SendingQueue implements Runnable, InitializingBean, DisposableBean {
    private static final Logger logger = LoggerFactory.getLogger(SendingQueue.class);

    /** Maximum number of commands that may be waiting to be sent. */
    private static final int CAPACITY = 65535;

    /** How long the idle consumer blocks before re-checking whether shutdown was requested. */
    private static final long POLL_TIMEOUT_MILLIS = 500L;

    private final BlockingQueue<Message> queue = new ArrayBlockingQueue<>(CAPACITY);

    /** Set by {@link #destroy()}; once true, {@link #send} drops new commands. */
    private final AtomicBoolean closing = new AtomicBoolean(false);

    /** Registered session per host id. */
    private final Map<Integer, WebSocketSession> sessionMap = new ConcurrentHashMap<>();

    private HostService hostService;

    /**
     * Registers {@code session} as the channel to the host whose IP matches the session's remote
     * address. If that host already has a registered session, the existing one is kept and this
     * one is ignored. Sessions whose IP matches no host are ignored.
     *
     * @param session the newly opened WebSocket session
     */
    public void newSession(WebSocketSession session) {
        Host host = resolveHost(session);
        if (host != null) {
            // Atomic check-and-insert (containsKey followed by put was racy).
            sessionMap.putIfAbsent(host.getId(), session);
        }
    }

    /**
     * Unregisters {@code session} from the host whose IP matches the session's remote address.
     *
     * @param session the WebSocket session that is being closed
     */
    public void removeSession(WebSocketSession session) {
        Host host = resolveHost(session);
        if (host != null) {
            // Remove only if THIS session is the registered one. A duplicate session that
            // newSession() ignored must not evict the host's active session when it closes.
            sessionMap.remove(host.getId(), session);
        }
    }

    /**
     * Looks up the {@link Host} whose IP equals the remote address of {@code session}.
     *
     * @return the matching host, or {@code null} if the remote address is unavailable or no host
     *     matches that IP
     */
    private Host resolveHost(WebSocketSession session) {
        InetSocketAddress remote = session.getRemoteAddress();
        if (remote == null || remote.getAddress() == null) {
            // No usable remote IP (address not exposed or unresolved): nothing to match against.
            return null;
        }
        return hostService.findByIp(remote.getAddress().getHostAddress());
    }

    /**
     * Consumer loop. Delivers queued commands until {@link #destroy()} has been called and the
     * queue has been drained. It runs on the thread started by {@link #afterPropertiesSet()}.
     */
    @Override
    public void run() {
        logger.info("开始消费队列");
        try {
            // Keep going until shutdown was requested AND nothing is left to send.
            while (!closing.get() || !queue.isEmpty()) {
                // Timed poll instead of a busy-spinning poll(). It blocks while idle but wakes up
                // periodically so a shutdown request is noticed.
                Message message = queue.poll(POLL_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
                if (message != null) {
                    deliver(message);
                }
            }
            logger.info("队列已消费完毕，销毁中。。。");
        } catch (InterruptedException e) {
            // Restore the interrupt flag for whoever owns this thread, then stop consuming.
            Thread.currentThread().interrupt();
            logger.warn("Sending queue consumer interrupted; {} queued command(s) not sent",
                    queue.size());
        }
    }

    /**
     * Sends one message to its host's registered session. Failures are logged, never propagated,
     * so one bad host or closed session cannot kill the single consumer thread.
     */
    private void deliver(Message message) {
        WebSocketSession session = this.sessionMap.get(message.getHostId());
        if (session == null) {
            logger.info("未找到连接到节点机的session,命令被丢弃: {}", message.getCmd().toString());
            return;
        }
        try {
            String cmdString = message.getCmd().toString();
            session.sendMessage(new TextMessage(cmdString));
        } catch (IOException | RuntimeException e) {
            // RuntimeException covers e.g. sending on a session that was closed concurrently.
            logger.error("Failed to send command to host {}", message.getHostId(), e);
        }
    }

    /**
     * Enqueues {@code cmd} for asynchronous delivery to host {@code hostId}. The command is
     * silently ignored once {@link #destroy()} has been called. If the queue is full, the command
     * is dropped and a warning is logged.
     *
     * @param hostId id of the target host; must not be null
     * @param cmd command to send; must not be null
     * @param <T> concrete command type
     * @throws NullPointerException if {@code hostId} or {@code cmd} is null. Rejecting it here
     *     prevents a later failure on the consumer thread.
     */
    public <T extends Command> void send(Integer hostId, T cmd) {
        Objects.requireNonNull(hostId, "hostId");
        Objects.requireNonNull(cmd, "cmd");
        if (closing.get()) {
            return;
        }
        // offer() instead of add(): a full queue must not throw IllegalStateException at the
        // producer. Note: a send racing with destroy() may enqueue after the consumer has
        // already exited; such a command is not delivered.
        if (!queue.offer(new Message(hostId, cmd))) {
            logger.warn("Sending queue is full, command dropped for host {}: {}", hostId, cmd);
        }
    }

    /**
     * Stops accepting new commands. The consumer thread keeps running until the commands already
     * queued have been processed, then exits. This method does not wait for that to happen.
     */
    @Override
    public void destroy() throws DestroyFailedException {
        closing.set(true);
    }

    /** Starts the consumer thread once the bean's properties have been set. */
    @Override
    public void afterPropertiesSet() throws Exception {
        // Named so the consumer is identifiable in thread dumps.
        new Thread(this, "sending-queue-consumer").start();
    }

    public void setHostService(HostService hostService) {
        this.hostService = hostService;
    }

    /**
     * Immutable pair of a target host id and the command to send to it. The class is static
     * because it needs no reference to the enclosing queue.
     */
    static final class Message {
        private final Integer hostId;
        private final Command cmd;

        public Message(Integer hostId, Command cmd) {
            this.hostId = hostId;
            this.cmd = cmd;
        }

        public Integer getHostId() {
            return hostId;
        }

        public Command getCmd() {
            return cmd;
        }
    }
}
